#!/usr/bin/env Rscript
# Usage: Rscript <this script> <slurm_id> <NSIMS>
#   slurm_id  job index; combined with the replicate number to build seeds
#   NSIMS     number of simulation replicates to run in this job
# Reads the data-generating parameters from bs_true_values.RDS and writes one
# RDS file per replicate and sample size into out/.

library(FactorHet)
library(dplyr) 
library(tidyr)
library(stringr)
library(cjbart)
library(reshape2)
library(mvtnorm)

# The directory may already exist (e.g. a re-run), which is not an error.
dir.create('out', showWarnings = FALSE)

args <- commandArgs(trailingOnly = TRUE)
if (length(args) < 2) {
  stop('Usage: Rscript <script> <slurm_id> <NSIMS>', call. = FALSE)
}
slurm_id <- as.numeric(args[1])
NSIMS <- as.numeric(args[2])
# Fail early with a clear message instead of an opaque error inside the loop.
if (is.na(slurm_id) || is.na(NSIMS) || NSIMS < 0) {
  stop('slurm_id must be numeric and NSIMS a non-negative number',
       call. = FALSE)
}

# Data-generating parameters shared by every replicate.
bs_truth <- readRDS('bs_true_values.RDS')
true_AME <- bs_truth$true_AME               # cluster-specific AMEs (long format)
true_population <- bs_truth$true_population
true_beta <- bs_truth$true_beta             # list indexed by cluster
true_phi <- bs_truth$true_phi               # moderator -> cluster coefficients
NCLUSTER <- bs_truth$NCLUSTER               # number of latent clusters
NJ <- bs_truth$NJ                           # number of attributes
NL <- bs_truth$NL                           # levels per attribute
NMOD <- bs_truth$NMOD                       # number of moderators

# Main simulation driver ----
# For each replicate in seq_len(NSIMS) and each sample size this script:
#   1. simulates a paired-profile conjoint whose respondents belong to latent
#      clusters drawn from softmax(moderators %*% true_phi), with outcomes
#      generated from the cluster's coefficients in true_beta;
#   2. fits FactorHet on all data, refits it, fits on a random half of the
#      respondents and refits that model on the other half;
#   3. compares against cjbart, against models with K = 1, 2, 4 clusters, and
#      against models with missing / non-linearly transformed moderators;
#   4. saves the results to out/sim_paper<slurm_id>_<sim>_<size>.RDS.

# Posterior cluster-membership probabilities of each respondent ('group')
# under `model`, evaluated on `newdata`.
posterior_by_group <- function(model, newdata) {
  detailed <- predict(model, newdata = newdata, return = 'detailed')
  data.frame(
    group = detailed$unique_new_group,
    detailed$posterior_predictive_group
  )
}

# With a single cluster there is no heterogeneity: every respondent in
# `groups` receives the (non-baseline) population AME from `ame`.
constant_HTE <- function(ame, groups) {
  population <- ame$data %>%
    filter(!baseline) %>%
    select(est = marginal_effect, var, factor, level)
  list(
    'individual' = crossing(data.frame(group = groups), population),
    'population' = population
  )
}

# Held-out predictions of `half_model` on `holdout`, tagged with the BIC of
# the full-data and half-data fits for the model-selection summary.
oos_predictions <- function(half_model, full_model, holdout) {
  prediction <- predict(half_model, newdata = holdout)
  out <- cbind(
    holdout %>% select(outcome, person, trial, choice),
    prediction = prediction
  )
  out$BIC <- BIC(full_model)
  out$BIC_half <- BIC(half_model)
  out
}

current_dir <- getwd()
tdir <- tempdir()

# seq_len() so that NSIMS = 0 runs nothing (1:0 would run sim = 1 and 0).
for (sim in seq_len(NSIMS)) {
  start_time <- Sys.time()
  for (size in c('small', 'large', 'very_large')) {

    message(size)
    # Work inside the session temp dir while fitting; the original working
    # directory is restored before saving. NOTE(review): presumably this keeps
    # any files written during fitting out of the job directory — confirm.
    setwd(tdir)

    # Distinct seed per (slurm_id, sim, size).
    seed <- 100 * slurm_id + sim + 10^6 * (size == 'large') +
      10^6 * 2 * (size == 'very_large')
    set.seed(seed)

    if (size == 'very_large') {
      NPEOPLE <- 4000
      NTRIALS <- 10
    } else if (size == 'large') {
      NPEOPLE <- 2000
      NTRIALS <- 10
    } else if (size == 'small') {
      NPEOPLE <- 1000
      NTRIALS <- 5
    } else {
      stop('Invalid Size')
    }

    message(paste0('Prepare: ', sim, '-', size))

    # Moderators: an intercept plus NMOD correlated normals with Toeplitz
    # correlation 0.25^|i - j|; cluster probabilities via softmax.
    sim_moderator <- cbind(
      1,
      mvtnorm::rmvnorm(n = NPEOPLE, sigma = toeplitz(0.25^(0:(NMOD - 1))))
    )
    sim_pi <- FactorHet:::softmax_matrix(sim_moderator %*% true_phi)

    sim_cluster <- apply(sim_pi, MARGIN = 1, FUN = function(i) {
      sample(1:NCLUSTER, 1, prob = i)
    })

    sim_data <- dplyr::bind_rows(lapply(seq_len(NPEOPLE), FUN = function(i) {
      cluster_i <- sim_cluster[i]
      dplyr::bind_rows(lapply(seq_len(NTRIALS), FUN = function(trial) {
        # 2 x NJ matrix: row 1 is the left profile, row 2 the right profile;
        # entry [r, j] is the level (1..NL) of attribute j.
        data_j <- vapply(seq_len(NJ), FUN = function(j) {
          sample(1:NL, 2, replace = TRUE)
        }, FUN.VALUE = integer(2))

        # Latent utility of each profile: sum of the cluster's coefficients at
        # the chosen levels (true_beta[[k]] assumed NJ x NL, matching the
        # indicator matrix).
        latent_j <- apply(data_j, MARGIN = 1, FUN = function(k) {
          sum(Matrix::sparseMatrix(i = 1:NJ, j = k, x = 1, dims = c(NJ, NL)) *
                true_beta[[cluster_i]])
        })
        # outcome_r = 1 means the right-hand profile was chosen.
        outcome_r <- rbinom(n = 1, size = 1, prob = plogis(diff(latent_j)))

        profiles <- data.frame(
          t(apply(data_j, MARGIN = 1, FUN = function(k) {
            paste0('f', 1:NJ, '_', letters[k])
          })),
          stringsAsFactors = FALSE
        )
        profiles$outcome <- c(1 - outcome_r, outcome_r)
        profiles$choice <- c('left', 'right')
        profiles
      }), .id = 'trial')
    }), .id = 'person')

    # Attach the (non-intercept) moderators V1..VNMOD to every row.
    sim_data <- dplyr::left_join(
      sim_data,
      cbind(as.data.frame(sim_moderator[, -1]),
            person = as.character(1:NPEOPLE)),
      by = c('person')
    )

    message(paste0('Standard: ', sim, '-', size))

    est_FH <- FactorHet_mbo(
      formula = outcome ~ X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + X10,
      design = sim_data, K = 3,
      moderator = ~ V1 + V2 + V3 + V4 + V5,
      group = ~ person, task = ~ trial, choice_order = ~ choice
    )
    refit_FH <- FactorHet_refit(object = est_FH, newdata = sim_data)

    message(paste0('Sample Split: ', sim, '-', size))
    # Fit on a random half of the respondents, then refit on the other half.
    uid <- sample(unique(sim_data$person), NPEOPLE / 2)
    train_data <- sim_data %>% filter(person %in% uid)
    holdout_data <- sim_data %>% filter(!(person %in% uid))

    est_FH_half <- FactorHet_mbo(
      formula = outcome ~ X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + X10,
      design = train_data, K = 3,
      moderator = ~ V1 + V2 + V3 + V4 + V5,
      group = ~ person, task = ~ trial, choice_order = ~ choice
    )
    est_FH_ssplit <- FactorHet_refit(object = est_FH_half,
                                     newdata = holdout_data)

    message(paste0('Process: ', sim, '-', size))

    # AME from the model fit on the whole data
    nonadapt_AME_bs <- AME(est_FH, verbose = FALSE, plot = FALSE)
    nonadapt_cjoint_bs <- cjoint_plot(est_FH, plot = FALSE)$data
    nonadapt_MFX_bs <- margeff_moderators(est_FH)$data
    # AME from the refit model (i.e. assume the sparsity pattern is correct
    # and then refit)
    refit_AME_bs <- AME(refit_FH, verbose = FALSE, plot = FALSE)
    refit_cjoint_bs <- cjoint_plot(refit_FH, plot = FALSE)$data
    # AME from a model fit on half the data as a "check"
    half_AME_bs <- AME(est_FH_half,
                       plot = FALSE, design = sim_data, verbose = FALSE)
    half_cjoint_bs <- cjoint_plot(est_FH_half, plot = FALSE)$data
    # AME from the model *refit* on the other half of the data
    ssplit_AME_bs <- AME(est_FH_ssplit,
                         plot = FALSE, design = sim_data, verbose = FALSE)
    ssplit_cjoint_bs <- cjoint_plot(est_FH_ssplit, plot = FALSE)$data

    # CAMCE for each individual and cluster membership probabilities
    unique_data <- sim_data %>% select(person, matches('^V[0-9]+')) %>% unique()

    nonadapt_HTE <- HTE_by_individual(
      object = est_FH, AME = nonadapt_AME_bs, design = unique_data)
    nonadapt_HTE$posterior_predictive <- posterior_by_group(est_FH, sim_data)

    refit_HTE <- HTE_by_individual(
      object = refit_FH, AME = refit_AME_bs, design = unique_data)
    # Cannot estimate the posterior predictive easily with "refit", so use
    # the model that has the same phi
    refit_HTE$posterior_predictive <- posterior_by_group(est_FH, sim_data)

    half_HTE <- HTE_by_individual(
      object = est_FH_half, AME = half_AME_bs, design = unique_data)
    half_HTE$posterior_predictive <- posterior_by_group(est_FH_half, sim_data)

    ssplit_HTE <- HTE_by_individual(
      object = est_FH_ssplit, AME = ssplit_AME_bs, design = unique_data)
    ssplit_HTE$posterior_predictive <- posterior_by_group(est_FH_half, sim_data)

    message(paste0('cjbart: ', sim, '-', size))

    cjbart_data <- sim_data %>%
      select(outcome, trial, person, matches('^[VX][0-9]+'))
    est_cjbart <- cjbart(
      data = cjbart_data,
      Y = 'outcome',
      id = 'person',
      round = 'trial', cores = 1, type = 'choice'
    )
    cjbart_IMCE <- IMCE(
      model = est_cjbart,
      attribs = paste0('X', 1:10),
      ref_levels = paste0('f', 1:10, '_a'),
      data = cjbart_data
    )

    cjbart_lower <- cjbart_IMCE$imce_lower
    cjbart_higher <- cjbart_IMCE$imce_upper

    cjbart_HTE <- cjbart_IMCE$imce %>%
      select(person, matches('^f[0-9]')) %>%
      unique()
    if (nrow(cjbart_HTE) != NPEOPLE) {
      stop('cjbart returned odd results')
    }
    # Long format: one row per (respondent, level), with the attribute
    # X<j> recovered from the level name f<j>_<letter>.
    cjbart_HTE <- cjbart_HTE %>%
      pivot_longer(cols = -person) %>%
      rename(level = name, est = value, group = person) %>%
      mutate(factor = paste0('X', str_extract(level, pattern = '[0-9]+'))) %>%
      relocate(group, factor, level, est)
    cjbart_AMCE <- cjbart_HTE %>%
      group_by(factor, level) %>%
      summarize(AMCE = mean(est), .groups = 'drop')

    cjbart_lower <- cjbart_lower %>%
      pivot_longer(cols = -person) %>%
      rename(level = name, ll = value, group = person)
    cjbart_higher <- cjbart_higher %>%
      pivot_longer(cols = -person) %>%
      rename(level = name, ul = value, group = person)
    cjbart_HTE <- full_join(cjbart_HTE, cjbart_lower, by = c('group', 'level')) %>%
      full_join(cjbart_higher, by = c('group', 'level'))

    # True individual effects: cluster-specific AMEs weighted by each
    # respondent's true cluster probabilities (rows of sim_pi).
    wide_AME <- true_AME %>%
      reshape2::dcast(factor + level ~ cluster, value.var = 'mean')
    true_HTE <- reshape2::melt(sim_pi %*% t(as.matrix(wide_AME[, -1:-2])))
    wide_AME$id <- seq_len(nrow(wide_AME))
    true_HTE <- left_join(
      true_HTE,
      wide_AME %>% select(id, factor, level),
      by = c('Var2' = 'id'))
    true_HTE <- true_HTE %>%
      rename(id = Var1, truth = value) %>%
      select(-Var2) %>%
      mutate(level = paste0('f', factor, '_', letters[level])) %>%
      mutate(factor = paste0('X', factor))

    if (nrow(cjbart_HTE) != nrow(nonadapt_HTE$individual)) {
      stop('cjbart alignment error')
    }

    message(paste0('wrong K: ', sim, '-', size))
    wrong_K_loop <- list()
    out_of_sample_accuracy <- list()

    for (K_loop in c(1, 2, 4)) {
      message(paste0('Wrong K:', K_loop))
      FH_wrong_K <- suppressMessages(FactorHet_mbo(
        formula = outcome ~ X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + X10,
        design = sim_data, K = K_loop,
        moderator = ~ V1 + V2 + V3 + V4 + V5,
        group = ~ person, task = ~ trial, choice_order = ~ choice
      ))

      half_FH_wrong_K <- suppressMessages(FactorHet_mbo(
        formula = outcome ~ X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + X10,
        design = train_data, K = K_loop,
        moderator = ~ V1 + V2 + V3 + V4 + V5,
        group = ~ person, task = ~ trial, choice_order = ~ choice
      ))

      AME_wrong_K <- AME(FH_wrong_K, verbose = FALSE, plot = FALSE)

      ssplit_wrong_K <- FactorHet_refit(object = half_FH_wrong_K,
                                        newdata = holdout_data)
      ssplit_AME_wrong_K <- AME(ssplit_wrong_K,
                                plot = FALSE, design = sim_data, verbose = FALSE)

      out_of_sample_accuracy[[K_loop]] <- oos_predictions(
        half_FH_wrong_K, FH_wrong_K, holdout_data)

      if (K_loop == 1) {
        wrong_K_HTE <- constant_HTE(AME_wrong_K, unique(sim_data$person))
        ssplit_wrong_K_HTE <- constant_HTE(ssplit_AME_wrong_K,
                                           unique(sim_data$person))
      } else {
        wrong_K_HTE <- HTE_by_individual(
          object = FH_wrong_K, AME = AME_wrong_K, design = unique_data)
        wrong_K_HTE$posterior_predictive <- posterior_by_group(FH_wrong_K,
                                                               sim_data)
        # One probability column per cluster plus 'group'.
        stopifnot(ncol(wrong_K_HTE$posterior_predictive) == K_loop + 1)

        ssplit_wrong_K_HTE <- HTE_by_individual(
          object = ssplit_wrong_K, AME = ssplit_AME_wrong_K,
          design = unique_data)
        ssplit_wrong_K_HTE$posterior_predictive <- posterior_by_group(
          half_FH_wrong_K, sim_data)
        # Same check for the split-sample posterior.
        stopifnot(ncol(ssplit_wrong_K_HTE$posterior_predictive) == K_loop + 1)
      }

      wrong_K_loop[[K_loop]] <- list(
        K = K_loop, check_K = ncol(coef(FH_wrong_K)),
        full_HTE = wrong_K_HTE,
        ssplit_HTE = ssplit_wrong_K_HTE)
      rm(FH_wrong_K, half_FH_wrong_K,
         ssplit_wrong_K, AME_wrong_K, ssplit_AME_wrong_K)
    }

    # Out-of-sample predictive accuracy for the correctly specified K = 3
    out_of_sample_accuracy[[3]] <- oos_predictions(est_FH_half, est_FH,
                                                   holdout_data)
    # Summarize out-of-sample accuracy using half-CV as well as the BIC;
    # 'K' is the list position, i.e. the number of clusters.
    out_of_sample_accuracy <- bind_rows(out_of_sample_accuracy, .id = 'K') %>%
      mutate(ll = ifelse(outcome == 0, log(1 - prediction), log(prediction))) %>%
      group_by(K) %>%
      summarize(BIC = mean(BIC), BIC_half = mean(BIC_half),
                ll = mean(ll), rmse = sqrt(mean((outcome - prediction)^2))) %>%
      ungroup()

    # Estimate models with mis-specified moderators
    misspec_loop <- list()
    for (misspec_type in c('no_mod', 'nonlinear_mod')) {
      message(misspec_type)

      if (misspec_type == 'no_mod') {
        misspec_fmla <- NULL
        do_misspec_hte <- FALSE
      } else if (misspec_type == 'nonlinear_mod') {
        # Create "bad" moderators that are non-linear transformations
        sim_data$B1 <- sqrt(3) * exp(sim_data$V1 / 2) - 2
        sim_data$B2 <- sqrt(3) * sim_data$V2 / (1 + exp(sim_data$V1))
        sim_data$B3 <- 1 / 19 * (sim_data$V1 * sim_data$V3 + 0.6)^3
        sim_data$B4 <- 1 / 3 * (sim_data$V2 + sim_data$V4)^2 - 1
        sim_data$B5 <- 2.5 * sqrt(abs(sim_data$V5 * sim_data$V1)) - 2.5
        # Standardize covariates to be mean-zero, unit-variance
        sim_data <- sim_data %>%
          ungroup() %>%
          mutate(across(matches('^B[0-9]'), ~ as.vector(scale(.))))
        misspec_fmla <- ~ B1 + B2 + B3 + B4 + B5
        do_misspec_hte <- TRUE
      } else {
        stop('Invalid misspec_type')
      }

      # sim_data may now carry the B* columns, so re-split it here.
      est_FH_misspec <- FactorHet_mbo(
        formula = outcome ~ X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + X10,
        design = sim_data, K = 3,
        moderator = misspec_fmla,
        group = ~ person, task = ~ trial, choice_order = ~ choice
      )
      refit_FH_misspec <- FactorHet_refit(object = est_FH_misspec,
                                          newdata = sim_data)
      est_FH_half_misspec <- FactorHet_mbo(
        formula = outcome ~ X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + X10,
        moderator = misspec_fmla,
        design = sim_data %>% filter(person %in% uid), K = 3,
        group = ~ person, task = ~ trial, choice_order = ~ choice
      )
      est_FH_ssplit_misspec <- FactorHet_refit(
        object = est_FH_half_misspec,
        newdata = sim_data %>% filter(!(person %in% uid)))

      # AMEs under moderator mis-specification. As in the main analysis, the
      # half-sample and split-sample AMEs are evaluated on the full data.
      nonadapt_AME_bs_misspec <- AME(est_FH_misspec, verbose = FALSE,
                                     plot = FALSE)
      half_AME_bs_misspec <- AME(est_FH_half_misspec, verbose = FALSE,
                                 plot = FALSE, design = sim_data)
      refit_AME_bs_misspec <- AME(refit_FH_misspec, verbose = FALSE,
                                  plot = FALSE)
      ssplit_AME_bs_misspec <- AME(est_FH_ssplit_misspec, verbose = FALSE,
                                   plot = FALSE, design = sim_data)

      nonadapt_misspec <- list(
        AME = nonadapt_AME_bs_misspec$data,
        cjoint = cjoint_plot(est_FH_misspec, plot = FALSE)$data,
        posterior_predictive = est_FH_misspec$posterior$posterior_predictive,
        posterior = est_FH_misspec$posterior$posterior
      )
      half_misspec <- list(
        AME = half_AME_bs_misspec$data,
        cjoint = cjoint_plot(est_FH_half_misspec, plot = FALSE)$data,
        posterior_predictive = est_FH_half_misspec$posterior$posterior_predictive,
        posterior = est_FH_half_misspec$posterior$posterior
      )
      refit_misspec <- list(
        AME = refit_AME_bs_misspec$data,
        cjoint = cjoint_plot(refit_FH_misspec, plot = FALSE)$data,
        posterior_predictive = est_FH_misspec$posterior$posterior_predictive,
        posterior = est_FH_misspec$posterior$posterior
      )
      ssplit_misspec <- list(
        AME = ssplit_AME_bs_misspec$data,
        cjoint = cjoint_plot(est_FH_ssplit_misspec, plot = FALSE)$data,
        posterior_predictive = est_FH_half_misspec$posterior$posterior_predictive,
        posterior = est_FH_half_misspec$posterior$posterior
      )

      if (do_misspec_hte) {
        misspec_design <- sim_data %>%
          select(person, matches('^B[0-9]+')) %>%
          unique()
        nonadapt_misspec$HTE <- HTE_by_individual(
          object = est_FH_misspec,
          AME = nonadapt_AME_bs_misspec,
          design = misspec_design)
        half_misspec$HTE <- HTE_by_individual(
          object = est_FH_half_misspec,
          AME = half_AME_bs_misspec,
          design = misspec_design)
        refit_misspec$HTE <- HTE_by_individual(
          object = refit_FH_misspec,
          AME = refit_AME_bs_misspec,
          design = misspec_design)
        ssplit_misspec$HTE <- HTE_by_individual(
          object = est_FH_ssplit_misspec,
          AME = ssplit_AME_bs_misspec,
          design = misspec_design)
      }

      misspec_loop[[misspec_type]] <- list(
        'non_adapt' = nonadapt_misspec,
        'half' = half_misspec,
        'refit' = refit_misspec,
        'ssplit' = ssplit_misspec
      )
    }

    message('Collecting')
    output_bs <- list(
      'seed' = seed,
      'OOS_accuracy' = out_of_sample_accuracy,
      'wrong_K' = wrong_K_loop,
      'true_cluster' = sim_cluster,
      'true_HTE' = true_HTE,
      # Additional Analyses
      'misspec' = misspec_loop,
      'cjbart' = list(HTE = cjbart_HTE, AMCE = cjbart_AMCE),
      # Primary Results
      'non_adapt' = list(AME = nonadapt_AME_bs$data,
                         HTE = nonadapt_HTE,
                         MFX = nonadapt_MFX_bs,
                         cjoint = nonadapt_cjoint_bs,
                         posterior_predictive = est_FH$posterior$posterior_predictive,
                         posterior = est_FH$posterior$posterior),
      'refit' = list(AME = refit_AME_bs$data,
                     cjoint = refit_cjoint_bs,
                     HTE = refit_HTE,
                     posterior_predictive = est_FH$posterior$posterior_predictive,
                     posterior = est_FH$posterior$posterior),
      'half' = list(AME = half_AME_bs$data,
                    cjoint = half_cjoint_bs,
                    HTE = half_HTE,
                    posterior_predictive = est_FH_half$posterior$posterior_predictive,
                    posterior = est_FH_half$posterior$posterior),
      'ssplit' = list(AME = ssplit_AME_bs$data,
                      cjoint = ssplit_cjoint_bs,
                      HTE = ssplit_HTE,
                      posterior_predictive = est_FH_half$posterior$posterior_predictive,
                      posterior = est_FH_half$posterior$posterior)
    )
    print('Intermediate time')
    # Elapsed time since the start of this replicate (positive).
    print(Sys.time() - start_time)
    message(paste0('Saving: ', sim, '-', size))
    setwd(current_dir)
    saveRDS(output_bs,
            paste0('out/sim_paper', slurm_id, '_', sim, '_', size, '.RDS'))
  }
  print('Full time')
  print(Sys.time() - start_time)
}

# Save and package output for transfer using OSG. The uncompressed out/
# directory is only removed once the archive has been written successfully,
# so a failed tar never destroys the simulation results.
system('ls out -lsta', intern = TRUE)
tar_status <- system("tar -czvf out.tar.gz out/")
system("ls")
if (tar_status != 0) {
  stop('Creating out.tar.gz failed (tar exit status ', tar_status,
       '); leaving out/ in place', call. = FALSE)
}
unlink("out", recursive = TRUE)
